#ifndef AHLGREN_SUBSTITUTIONS
#define AHLGREN_SUBSTITUTIONS

#include <algorithm> // find, find_if
#include <iterator> // next
#include <map> // substitution function
#include <set> // for subst domain
#include <utility> // pair, move
#include "global.h"

namespace lp {
	using namespace std;

	class Functor;

	// ********** Subst ********** //

	class Subst {
	public:
		// Substitution function: variable id -> (bound term, owned flag).
		// The bool is true iff this Subst owns the term pointer and deletes it on erase/clear.
		typedef map<id_type,pair<const Functor*,bool> > subst_fun;
		typedef subst_fun::value_type subst_type;
		typedef subst_fun::const_iterator const_iterator;
		typedef subst_fun::iterator iterator;
		typedef subst_fun::size_type size_type;
		// Construct empty substitution
		Subst() { }
		// Copy constructor: every bound term is deep-copied (and owned)
		Subst(const Subst& s) { copy(s); }
		// Construct substitution from only the variables in [beg,end[ that s binds, each expanded
		template <class Iter> Subst(const Subst& s, Iter beg, Iter end);
		// Construct substitution from only variable var of s (expanded); empty if s does not bind var
		Subst(const Subst& s, id_type var);
		// Copy assignment operator (deep copy)
		Subst& operator=(const Subst&);
		// Move constructor/assignment: take over s's bindings, then deep-copy the shallow ones
		Subst(Subst&& s) { swap(s); deepen(); }
		Subst& operator=(Subst&& s);
		// Destructor: deletes every owned term
		~Subst() { clear(); }
		// Clear all bindings in substitution (deleting owned terms)
		void clear();
		// Emptiness / number of bindings
		bool is_empty() const { return subs.empty(); }
		size_type size() const { return subs.size(); }

		// Comparisons (does NOT expand)
		bool operator==(const Subst& s) const;
		bool operator!=(const Subst& s) const { return !(*this == s); }
		// Is substitution ground? (every bound term is ground)
		bool is_ground() const;
		// Swap bindings with s (no terms are copied)
		void swap(Subst& s) noexcept { subs.swap(s.subs); }
		// Make deep copies of shallow ones
		void deepen();
		// Make deep copy of variable s; returns false iff s is not bound
		bool deepen(id_type s);
		// Copy constructor/assignment operator helper: deep-copies the argument's bindings
		void copy(const Subst&);
		// Shallow copy constructor/assignment operator helper: adds non-owning bindings
		void shallow_copy(const Subst&);
		// Add substitution (soft pointer: not owned, never deleted); no-op if s is already bound
		void soft_add(id_type s, const Functor* expr);
		// Erase binding of s (deleting the term if owned); returns false iff s was not bound
		bool erase(id_type s);
		// Erase binding at i (deleting the term if owned); returns iterator to the following binding
		iterator erase(iterator i);
		// Remove anonymous (reserved) variables from the substitution
		void erase_anonymous();
		// Try to add (soft pointer) substitution. Returns true if s was unbound (binding added)
		// or is already bound to an equal term; returns false iff s is bound to a different term
		bool add_if(id_type s, const Functor* expr);
		// Steal substitution: add an owning binding var -> expr (nothing is inserted if var is bound)
		void steal(id_type var, const Functor* expr);
		// Like add_if, but owning: on true, expr is either stored or deleted as a duplicate;
		// on false, expr is neither stored nor deleted
		bool steal_if(id_type var, const Functor* expr);
		// Get substitution (nullptr if s is not bound)
		const Functor* get(id_type s) const;
		// Create unifier: replace every binding by its full expansion, e.g. {x->y,y->z} -> {x->z,y->z}
		void create_unifier();
		// Create match: make every binding owned (deep-copy the shallow ones)
		void create_match();
		// Restrict substitution to the following range of variables (must be expanded first)
		template <typename Iterator> void restrict(Iterator first, Iterator end);
		// Print
		void print(ostream& os) const;
		// Get domain of variables
		set<id_type> domain() const;
		// Expand substitution
		void expand();
		// Expand substitution with respect to only [beg,end[, erase others
		template <typename Iter> void expand(Iter first, Iter end);
		// Get expanded copy
		Subst get_expanded() const;
		// Get expanded copy with respect to only one variable
		Subst get_expanded(id_type s) const;
		// Get expanded copy with respect to variables [beg,end[
		template <typename Iter> Subst get_expanded(Iter beg, Iter end) const;
		// Iterators
		const_iterator begin() const { return subs.begin(); }
		const_iterator end() const { return subs.end(); }
		iterator begin() { return subs.begin(); }
		iterator end() { return subs.end(); }
	protected:
		subst_fun subs; // substitution function
		// Helper for creating a unifier: full expansion of f under this substitution.
		// Callers here store the result as owned, so it is presumably a fresh heap allocation.
		Functor* expand(const Functor* f) const;
		// Move-assign subs without deepening/expanding (more efficient than operator=(Subst&&))
		void assign(Subst&&);
	};

} // namespace lp
#include "functor.h"
namespace lp {
	
	// Stream a substitution by delegating to Subst::print; returns os for chaining.
	inline ostream& operator<<(std::ostream& os, const lp::Subst& subs)
	{
		subs.print(os);
		return os;
	}

	// Build a substitution holding only the binding of var taken from s, fully
	// expanded and owned by this object. Empty if s does not bind var.
	inline Subst::Subst(const Subst& s, id_type var)
	{
		const auto found = s.subs.find(var);
		if (found == s.subs.end()) return;
		subs.insert(make_pair(var, make_pair(s.expand(found->second.first), true)));
	}

	// Take over s's bindings without deepening or expanding them (cheaper than
	// operator=(Subst&&)). Terms previously owned by *this are freed; s is left empty.
	inline void Subst::assign(Subst&& s)
	{
		clear();
		// Swap rather than map move-assignment: a moved-from std::map is only
		// "valid but unspecified", and any entries left in s would be deleted again
		// by s's destructor. After clear() + swap, s is guaranteed empty.
		subs.swap(s.subs);
	}

	// Copy assignment: replace our bindings with deep copies of other's.
	// Copy-and-swap: the copy is built before anything is released, so if copying
	// throws, *this is left unchanged. Our old terms are freed by tmp's destructor.
	inline Subst& Subst::operator=(const Subst& other)
	{
		if (this != &other) {
			Subst tmp(other);
			swap(tmp);
		}
		return *this;
	}

	inline bool Subst::is_ground() const
	{
		for (auto i = begin(); i != end(); ++i) {
			if (!i->second.first->is_ground()) return false;
		}
		return true;
	}

	// Ensure the binding of v is owned, deep-copying it if it was shallow.
	// Returns false iff v is not bound.
	inline bool Subst::deepen(id_type v)
	{
		const iterator at = subs.find(v);
		if (at == subs.end()) return false;
		auto& entry = at->second;
		if (!entry.second) entry = make_pair(entry.first->copy(), true);
		return true;
	}

	// Deep-copy every shallow (non-owned) binding so that all terms are owned.
	inline void Subst::deepen()
	{
		for (auto& binding : subs) {
			auto& entry = binding.second;
			if (entry.second) continue; // already owned
			entry = make_pair(entry.first->copy(), true);
		}
	}

	
	inline void Subst::copy(const Subst& subs)
	{
		for (const_iterator i = subs.subs.begin(); i != subs.subs.end(); ++i) {
			// this->subs[i->first] = make_pair(i->second.first->copy(),true);
			this->subs.insert(make_pair(i->first,make_pair(i->second.first->copy(),true)));
		}
	}

	
	inline void Subst::shallow_copy(const Subst& subs)
	{
		for (const_iterator i = subs.subs.begin(); i != subs.subs.end(); ++i) {
			// this->subs[i->first] = make_pair(i->second.first,false);
			this->subs.insert(make_pair(i->first,make_pair(i->second.first,false)));
		}
	}

	
	// Delete every owned term, then drop all bindings.
	inline void Subst::clear()
	{
		for (const auto& binding : subs) {
			if (binding.second.second) delete binding.second.first;
		}
		subs.clear();
	}

	
	// Term bound to s, or nullptr if s is not bound. Ownership is not transferred.
	inline const Functor* Subst::get(id_type s) const
	{
		const auto at = subs.find(s);
		if (at == subs.end()) return nullptr;
		return at->second.first;
	}

	// New substitution holding only the fully expanded (owned) binding of varname;
	// empty if varname is not bound.
	inline Subst Subst::get_expanded(id_type varname) const
	{
		Subst result;
		const const_iterator at = subs.find(varname);
		if (at != subs.end()) {
			const Functor* expanded = expand(at->second.first);
			result.subs.insert(make_pair(varname, make_pair(expanded, true)));
		}
		return result;
	}

	// New substitution holding the fully expanded (owned) bindings of the variables
	// in [beg,end[; variables not bound here are skipped.
	template <typename Iter>
	inline Subst Subst::get_expanded(Iter beg, Iter end) const
	{
		Subst esubs;
		for ( ; beg != end; ++beg) {
			const auto at = subs.find(*beg);
			if (at == subs.end()) continue;
			// Reserve the slot with a non-owning placeholder before expanding, so a variable
			// repeated in the range is expanded only once: expanding first and then having
			// the insert rejected leaked the expansion.
			auto slot = esubs.subs.insert(make_pair(at->first, make_pair(at->second.first, false)));
			if (!slot.second) continue;
			slot.first->second = make_pair(expand(at->second.first), true);
		}
		return esubs;
	}

	// New substitution with every binding replaced by its full expansion (all owned).
	inline Subst Subst::get_expanded() const
	{
		Subst result;
		for (const auto& binding : subs) {
			result.subs.insert(make_pair(binding.first, make_pair(expand(binding.second.first), true)));
		}
		return result;
	}

	// Set of all bound variable ids.
	inline set<id_type> Subst::domain() const
	{
		set<id_type> keys;
		// Map keys arrive in sorted order, so hinting at end() is always correct
		for (const auto& binding : subs) keys.insert(keys.end(), binding.first);
		return keys;
	}


	// Build a substitution holding only the variables in [beg,end[ that s binds,
	// each fully expanded and owned by this object. Unbound variables are skipped.
	template <class Iter>
	Subst::Subst(const Subst& s, Iter beg, Iter end)
	{
		for ( ; beg != end; ++beg) {
			const auto at = s.subs.find(*beg);
			if (at == s.subs.end()) continue;
			// Reserve the slot with a non-owning placeholder before expanding, so a variable
			// repeated in the range is expanded only once and no expansion is allocated
			// for an insert that would be rejected.
			auto slot = subs.insert(make_pair(at->first, make_pair(at->second.first, false)));
			if (!slot.second) continue;
			slot.first->second = make_pair(s.expand(at->second.first), true);
		}
	}


	template <typename Iterator>
	void Subst::restrict(Iterator first, Iterator end)
	{
		iterator i = subs.begin();
		while (i != subs.end()) {
			if (find(first,end,i->first) == end) {
				// i->first not found, erase
				iterator tmp = i;
				++tmp;
				subs.erase(i);
				i = tmp;
			} else ++i; // move to next i
		}
	}

	// Calling operator=(Subst&&) makes an unnecessary call to deepen()
	inline void Subst::expand() 
	{ 
		assign(get_expanded()); 
	}

	template <typename Iter>
	void Subst::expand(Iter curr, Iter end)
	{
		assign(get_expanded(curr,end));
	}


	// Add a non-owning binding s -> expr (never deleted by this Subst); no-op if s is already bound.
	inline void Subst::soft_add(id_type s, const Functor* expr)
	{
		subs.insert(subst_type(s, make_pair(expr, false)));
	}

	// Erase the binding of s (deleting the term if owned). Returns false iff s was not bound.
	inline bool Subst::erase(id_type s) {
		const iterator at = subs.find(s);
		const bool found = (at != subs.end());
		if (found) this->erase(at);
		return found;
	}

	inline auto Subst::erase(iterator i) -> iterator {
		if (i->second.second) delete i->second.first;
		return subs.erase(i);
	}

	inline void Subst::erase_anonymous() {
		for (auto i = subs.begin(); i != subs.end(); ) {
			// Erase _G = term
			// Note: we can't erase X -> _G since we may have Y -> _G or Z -> f(_G)
			if (functor_map::is_reserved_var(i->first)) {
				i = this->erase(i);
			} else if (i->second.first->is_reserved_variable()) {
				// X = _G and no other binding to _G => remove X (trivial)
				// X=_G, Y=_G, Z=_G, => X = Y, Y = Z
				const id_type var = i->second.first->id();
				auto at = std::find_if(std::next(i),subs.end(),[&](const subst_type& p){ 
					return p.second.first->id() == var; 
				});
				if (at == subs.end()) {
					// Not found, so X = _G is the only binding to _G
					i = this->erase(i);
				} else {
					// X=_G and Y=_G, so X=Y
					if (i->second.second) delete i->second.first;
					i->second.first = new Functor(at->first);
					i->second.second = true;
					this->erase(at);
					++i;
				}
			} else ++i;
		}
	}

	// Add a non-owning binding s -> expr if s is unbound. Returns true if the binding was
	// added or s is already bound to an equal term; false if s is bound to a different term.
	inline bool Subst::add_if(id_type s, const Functor* expr) {
		const const_iterator at = subs.find(s);
		if (at != subs.end()) {
			return *at->second.first == *expr;
		}
		soft_add(s, expr);
		return true;
	}

	// Add an owning binding var -> expr: this Subst deletes expr when the binding is removed.
	// NOTE(review): if var is already bound, nothing is inserted and expr is neither stored
	// nor deleted, so it leaks unless the caller frees it. Confirm callers never steal onto
	// a bound variable.
	inline void Subst::steal(id_type var, const Functor* expr)
	{
		subs.insert(subst_type(var, make_pair(expr, true)));
	}

	// Owning counterpart of add_if.
	//  - var unbound: store expr as an owned binding, return true;
	//  - var bound to an equal term: delete expr (duplicate), return true;
	//  - var bound to a different term: expr is neither stored nor deleted, return false.
	inline bool Subst::steal_if(id_type var, const Functor* expr) {
		const const_iterator at = subs.find(var);
		if (at == subs.end()) {
			steal(var, expr);
			return true;
		}
		if (*at->second.first == *expr) {
			delete expr;
			return true;
		}
		return false;
	}

	// Move assignment: free our owned terms, take over s's bindings, then deep-copy
	// any that were shallow so that every binding is owned. s is left empty.
	inline Subst& Subst::operator=(Subst&& s) 
	{ 
		if (this != &s) { 
			clear();
			// Swap rather than map move-assignment: a moved-from std::map is only
			// "valid but unspecified", and entries left in s would be deleted again
			// by s's destructor. After clear() + swap, s is guaranteed empty.
			subs.swap(s.subs);
			deepen(); 
		} 
		return *this; 
	}

	inline void Subst::create_unifier()
	{
		// Follow substitutions, so e.g. {x->y,y->z} becomes {x->z,y->z}
		subst_fun subs_tmp;
		for (iterator i = subs.begin(); i != subs.end(); ++i) {
			subs_tmp.insert(make_pair(i->first,make_pair(expand(i->second.first),true)));
		}
		clear();
		subs = move(subs_tmp);
	}

	// Make every binding owned: shallow terms are replaced by deep copies, owned ones are left as-is.
	inline void Subst::create_match()
	{
		for (auto& binding : subs) {
			auto& entry = binding.second;
			if (!entry.second) {
				entry.first = entry.first->copy();
				entry.second = true;
			}
		}
	}

} // namespace lp

// Overload swap for Subst
namespace std {
	
	inline void swap(lp::Subst& s1, lp::Subst& s2) { s1.swap(s2); }
}


#endif
